import pickle
import traceback
import sqlite3
import os
import time
import math

class ExpandedExperimentDatabase(object):
    """Cache experiment measurements in memory and persist them with save().

    While the experiment runs, data is kept in dictionaries keyed by
    (state, theta_b, phi_b).  save() writes it to an SQLite database
    (<fname_base>.db) plus plain-text summary files, all under
    ./Outputs/<experiment_name>/.  When used as a context manager, the data is
    saved when the ``with`` block exits, even if it exits with an exception.
    """

    # Number of per-tile phase columns stored in the phase_settings table.
    _NUM_TILES = 25

    def __init__(self, experiment_name, theta_b_vals, phi_b_vals, theta_m_vals,
                 phi_m_vals, num_bits, save_characterization=True,
                 save_debug=False, save_spectrums=False,
                 measurement_number=None, verbose=True):
        """Create the output folder and choose the file base name for this run.

        Args:
            experiment_name: used for the output folder and file names.
            theta_b_vals, phi_b_vals: beam angles explored for every state.
            theta_m_vals, phi_m_vals: measurement angles recorded for every
                (state, theta_b, phi_b) combination.
            num_bits: bit count of a state; only used to size the declared
                SQL type of the state column.
            save_characterization: also write _optimization.data and
                _times.data on save().
            save_debug: also write _debug.data on save().
            save_spectrums: also write _spectrums.data on save().
            measurement_number: explicit run number.  If None, the first
                unused run number (starting at 1) is chosen.
            verbose: print "SAVING DATA" when saving on a normal exit.

        Raises:
            IOError: measurement_number was given and that run already has
                output files.
        """
        self.experiment_name = experiment_name
        self.file_location = "./Outputs/" + experiment_name + "/"
        self.save_characterization = save_characterization
        self.save_debug = save_debug
        self.save_spectrums = save_spectrums

        # Declared length of the SQL state column.  SQLite ignores lengths in
        # type names, so this is informational only.
        self.state_size = int(math.floor(num_bits/8) + 1)
        self.num_bits = num_bits

        self.verbose = verbose

        # If folder doesn't exist, then create it.
        if not os.path.isdir(self.file_location):
            os.makedirs(self.file_location)

        run_prefix = self.file_location + self.experiment_name + "_measurement_"
        if measurement_number is None:
            # Pick the first unused run number so data isn't overwritten by accident.
            n = 1
            while self._run_exists(run_prefix + str(n)):
                n += 1
            fname = run_prefix + str(n)
        else:
            fname = run_prefix + str(measurement_number)
            if self._run_exists(fname):
                print("EXPERIMENT ALLREADY EXISTS, CHANGE NAME")
                # Raise rather than return: an early return left a half-built
                # object whose every later call (including save()) crashed.
                raise IOError("experiment already exists: " + fname)

        self.fname_base = fname

        self.theta_b_vals = theta_b_vals
        self.phi_b_vals = phi_b_vals
        self.theta_m_vals = theta_m_vals
        self.phi_m_vals = phi_m_vals

        # Reference time; all recorded times are seconds relative to this.
        self.start_time = time.time()

        # All states added so far, in insertion order.
        self.states = []

        # (state, theta_b, phi_b) -> list of per-tile phases.
        self.phase_settings = {}

        # (state, theta_b, phi_b) -> 2D list indexed [phi_m_idx][theta_m_idx];
        # entries are created by addBatchOfStates().
        self.measurement_data = {}

        # Same layout as measurement_data, holding the time each value was
        # entered (relative to start_time).
        self.measurement_time = {}

        # Ordered list of batches (generations), each being the states explored in it.
        self.batch_data = []

        # state -> characterization value (None until saveCharacterizationData).
        self.state_characterizations = {}

        # state -> time the characterization value was recorded.
        self.state_characterization_times = {}

        # List of (time, temp, supply) tuples.
        self.debug_data = []

        # state -> spectrum data.
        self.spectrums = {}

    @staticmethod
    def _run_exists(fname):
        """Return True if output files already exist for run base name fname."""
        return os.path.isfile(fname + "_top.data") or os.path.isfile(fname + ".db")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Save all data on exit; print the traceback first if the body crashed.

        Returns None, so any exception from the ``with`` body still propagates.
        """
        if exc_type is not None:
            traceback.print_exception(exc_type, exc_value, tb)
            print("CODE CRASHED: SAVING DATA")
        elif self.verbose:
            print("SAVING DATA")
        # save data no matter what
        self.save()

    def convertStateToSQLState(self, state):
        """Placeholder; currently does nothing and returns None."""
        return

    def addBatchOfStates(self, states):
        """Register a batch of states and allocate empty data slots for them.

        For every state and every (theta_b, phi_b) pair this creates an
        all-None [phi_m_idx][theta_m_idx] grid for measurement data and times,
        plus an all-None list of per-tile phase settings.
        """
        for state in states:
            self.state_characterizations[state] = None
            for theta_b in self.theta_b_vals:
                for phi_b in self.phi_b_vals:
                    key = (state, theta_b, phi_b)
                    self.measurement_data[key] = [[None for theta_m in self.theta_m_vals] for phi_m in self.phi_m_vals]
                    self.measurement_time[key] = [[None for theta_m in self.theta_m_vals] for phi_m in self.phi_m_vals]
                    self.phase_settings[key] = [None] * self._NUM_TILES
        self.states += states
        self.batch_data += [states]

    def save(self):
        """Write all cached data to disk.

        (Re)creates the tables phase_settings, measurement_data and
        measurement_time_data in <fname_base>.db, then writes
        <fname_base>_top.data and, depending on the constructor flags,
        the _optimization, _times, _debug and _spectrums text files.
        Each call replaces the previous snapshot, so save() may be called
        repeatedly (e.g. manually and again from __exit__).
        """
        self._save_database()
        self._save_top_file()
        if self.save_characterization:
            self._save_characterization_files()
        if self.save_debug:
            self._save_debug_file()
        if self.save_spectrums:
            self._save_spectrum_file()

    def _state_column(self):
        """Return the SQL column declaration for the state column."""
        return "state BINARY(" + str(self.state_size) + ")"

    def _save_database(self):
        """Write phase settings, measurement data and measurement times to SQLite."""
        database = sqlite3.connect(self.fname_base + ".db")
        try:
            self._save_phase_settings(database)
            self._save_angle_table(database, "measurement_data", "data", self.measurement_data)
            self._save_angle_table(database, "measurement_time_data", "time", self.measurement_time)
            database.commit()
        finally:
            database.close()

    def _save_phase_settings(self, database):
        """(Re)create phase_settings with one row per (state, theta_b, phi_b)."""
        num_tiles = self._NUM_TILES
        tile_columns = ", ".join(["tile_%d_phase REAL" % (i + 1) for i in range(num_tiles)])
        database.execute("DROP TABLE IF EXISTS phase_settings")
        database.execute("CREATE TABLE phase_settings (" + self._state_column()
                         + ", beam_theta REAL, beam_phi REAL, " + tile_columns + ")")
        num_columns = 3 + num_tiles
        insert_cmd = "INSERT INTO phase_settings VALUES (" + ",".join(["?"] * num_columns) + ")"
        rows = []
        for state in self.states:
            state_bin = bin(state)
            for theta_b in self.theta_b_vals:
                for phi_b in self.phi_b_vals:
                    # list() so tuples given to enterBeamPhaseSettings also work.
                    row = [state_bin, theta_b, phi_b] + list(self.phase_settings[(state, theta_b, phi_b)])
                    # Pad missing tiles with NULL.
                    row += [None] * (num_columns - len(row))
                    rows.append(row)
        database.executemany(insert_cmd, rows)

    def _save_angle_table(self, database, table_name, column_prefix, cache):
        """(Re)create a table with one column per measurement angle and fill it.

        Columns are ordered phi_m outermost and theta_m innermost, matching the
        [phi_m_idx][theta_m_idx] layout of the grids stored in cache.
        """
        columns = []
        for phi_m in self.phi_m_vals:
            for theta_m in self.theta_m_vals:
                name = column_prefix + "_phi_" + str(phi_m) + "_theta_" + str(theta_m)
                # Historical naming writes '-' as 'n'.  The identifier is
                # quoted so float angles such as 1.5 (which contain '.') are
                # still valid SQL.
                name = name.replace("-", "n").replace('"', '""')
                columns.append('"' + name + '" REAL')
        database.execute("DROP TABLE IF EXISTS " + table_name)
        database.execute("CREATE TABLE " + table_name + " ("
                         + ", ".join([self._state_column(), "beam_theta REAL", "beam_phi REAL"] + columns)
                         + ")")
        insert_cmd = "INSERT INTO " + table_name + " VALUES (" + ",".join(["?"] * (3 + len(columns))) + ")"
        rows = []
        for state in self.states:
            state_bin = bin(state)
            for theta_b in self.theta_b_vals:
                for phi_b in self.phi_b_vals:
                    row = [state_bin, theta_b, phi_b]
                    for per_phi_values in cache[(state, theta_b, phi_b)]:
                        row.extend(per_phi_values)
                    rows.append(row)
        database.executemany(insert_cmd, rows)

    def _save_top_file(self):
        """Write the experiment description (angles and states) to _top.data."""
        with open(self.fname_base + "_top.data", "w") as f:
            f.write("File base name: " + str(self.fname_base) + "\n")
            f.write("Beam Theta Values: " + str(self.theta_b_vals) + "\n")
            f.write("Beam Phi Values: " + str(self.phi_b_vals) + "\n")
            f.write("Measured Theta Values: " + str(self.theta_m_vals) + "\n")
            f.write("Measured Phi Values: " + str(self.phi_m_vals) + "\n")
            f.write("States: " + str(self.states) + "\n")

    def _write_batches(self, f, values, label):
        """Write each batch's states followed by the values recorded for them."""
        for batch_number, batch in enumerate(self.batch_data, 1):
            batch_vals = [values[state] for state in batch if state in values]
            f.write("Batch: " + str(batch_number) + "\n")
            f.write("States: " + str(batch) + "\n")
            f.write(label + str(batch_vals) + "\n")

    def _save_characterization_files(self):
        """Write per-batch characterization values and their recording times."""
        with open(self.fname_base + "_optimization.data", "w") as f:
            f.write("File base name: " + str(self.fname_base) + "\n")
            self._write_batches(f, self.state_characterizations, "Characterization Values:")

        with open(self.fname_base + "_times.data", "w") as f:
            f.write("File base name: " + str(self.fname_base) + "\n")
            f.write("Measurement Start Time: " + str(self.start_time) + "\n")
            self._write_batches(f, self.state_characterization_times, "Characterization Times:")

    def _save_debug_file(self):
        """Write the debug samples as parallel time/temperature/supply lists."""
        with open(self.fname_base + "_debug.data", "w") as f:
            f.write("File base name: " + str(self.fname_base) + "\n")
            f.write("Measurement Start Time: " + str(self.start_time) + "\n")
            f.write("Times: " + str([point[0] for point in self.debug_data]) + "\n")
            f.write("Temps: " + str([point[1] for point in self.debug_data]) + "\n")
            f.write("Supplys:" + str([point[2] for point in self.debug_data]) + "\n")

    def _save_spectrum_file(self):
        """Write one spectrum line per state; None for states without a spectrum."""
        with open(self.fname_base + "_spectrums.data", "w") as f:
            f.write("States: " + str(self.states) + "\n")
            for state in self.states:
                # .get(): a missing spectrum must not abort saving.
                f.write("Spectrum: " + str(self.spectrums.get(state)) + "\n")

    def enterBeamPhaseSettings(self, state, theta_b, phi_b, phase_settings):
        """Store the per-tile phase settings for (state, theta_b, phi_b)."""
        self.phase_settings[(state, theta_b, phi_b)] = phase_settings

    def getBeamPhaseSettings(self, state, theta_b, phi_b):
        """Return the stored phase settings (the cached object itself; DO NOT MODIFY)."""
        return self.phase_settings[(state, theta_b, phi_b)]

    def enterPatternData(self, state, theta_b, phi_b, theta_m_idx, phi_m_idx, data):
        """Store one measurement value and record when it was entered.

        The grids are indexed [phi_m_idx][theta_m_idx]; the time is stored
        relative to start_time.
        """
        key = (state, theta_b, phi_b)
        self.measurement_data[key][phi_m_idx][theta_m_idx] = data
        self.measurement_time[key][phi_m_idx][theta_m_idx] = time.time() - self.start_time

    def getPatternData(self, state, theta_b, phi_b):
        """Return the [phi_m_idx][theta_m_idx] grid (the cached object itself; DO NOT MODIFY)."""
        return self.measurement_data[(state, theta_b, phi_b)]

    def saveCharacterizationData(self, state, characterization_value):
        """Record a characterization value for state and when it was recorded."""
        self.state_characterizations[state] = characterization_value
        self.state_characterization_times[state] = time.time() - self.start_time

    def getCharacterizationData(self, state):
        """Return the characterization value for state (the cached object itself; DO NOT MODIFY)."""
        return self.state_characterizations[state]

    def addDebugData(self, temp, supply):
        """Append a (time since start, temp, supply) debug sample."""
        time_point = time.time() - self.start_time
        self.debug_data += [(time_point, temp, supply)]

    def addSpectrum(self, state, data):
        """Store spectrum data for state."""
        self.spectrums[state] = data